In [1]:
import pandas as pd
import numpy as np
import altair as alt

import theme
from natsort import natsorted, natsort_keygen

# Lift Altair's max-rows check so the full mutation-level tables below can be
# passed directly to alt.Chart.
alt.data_transformers.disable_max_rows()
Out[1]:
DataTransformerRegistry.enable('default')
In [2]:
# Per-mutation entry effects (and their std) in the H3, H5 and H7 backgrounds, keyed by struct_site.
mutation_effects = pd.read_csv(
    '../results/combined_effects/combined_mutation_effects.csv'
)
mutation_effects.head()
Out[2]:
mutant struct_site h3_wt_aa h5_wt_aa h7_wt_aa rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 4o5n_aa_RSA 4kwm_aa_RSA 6ii9_aa_RSA h3_effect h3_effect_std h5_effect h5_effect_std h7_effect h7_effect_std
0 A 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.0151 0.7225 0.2049 0.2627 NaN NaN
1 C 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN -0.4080 0.3850 -0.3977 0.1072 NaN NaN
2 D 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.2361 0.2740 0.2383 0.2087 NaN NaN
3 E 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN -0.2463 0.8478 0.3120 0.2815 NaN NaN
4 F 9 S K NaN 9.1674 NaN NaN 1.084277 1.140252 NaN 0.2061 0.3214 -0.8917 1.2020 NaN NaN
In [3]:
# Site-level averages of the entry effects (avg_<ha>_effect), one row per struct_site.
site_effects = pd.read_csv(
    '../results/combined_effects/combined_site_effects.csv'
)
site_effects.head()
Out[3]:
struct_site h3_wt_aa h5_wt_aa h7_wt_aa rmsd_h3h5 rmsd_h3h7 rmsd_h5h7 4o5n_aa_RSA 4kwm_aa_RSA 6ii9_aa_RSA avg_h3_effect avg_h5_effect avg_h7_effect
0 9 S K NaN 9.167400 NaN NaN 1.084277 1.140252 NaN -0.050776 -0.998095 NaN
1 10 T S NaN 8.157247 NaN NaN 0.150962 0.175962 NaN -0.697911 -3.348267 NaN
2 11 A D D 5.040040 2.984626 2.886615 0.050388 0.097927 0.624352 -3.138280 -3.951383 -2.962194
3 12 T Q K 3.937602 1.626754 3.384350 0.268605 0.216889 0.368644 -1.036219 -0.342761 -1.705403
4 13 L I I 3.687798 1.734039 2.549524 0.000000 0.000000 0.000000 -3.941050 -3.827571 -3.829644
In [4]:
# Pairwise HA protein sequence identity (percent_identity per ha_x / ha_y pair).
seq_identity = pd.read_csv(
    '../results/sequence_identity/ha_sequence_identity.csv'
)
seq_identity.head()
Out[4]:
ha_x ha_y matches alignable_residues percent_identity
0 H3 H5 192.0 479.0 40.083507
1 H3 H7 229.0 483.0 47.412008
2 H5 H7 202.0 473.0 42.706131
In [5]:
# Axis titles per HA background; each subtype's entry effects were measured in a different cell line.
ENTRY_EFFECT_AXIS_TITLES = {
    'h3': ['Effect on MDCK-SIAT1 entry', 'in H3 background'],
    'h5': ['Effect on 293T entry', 'in H5 background'],
    'h7': ['Effect on 293-a2,6 entry', 'in H7 background'],
}


def mutation_effect_scatter(df, ha_x, ha_y, axis_titles=None):
    """
    Scatter of per-mutation entry effects in two HA backgrounds.

    Parameters
    ----------
    df : pd.DataFrame
        Mutation-level table with ``{ha}_effect`` and ``{ha}_wt_aa`` columns.
    ha_x, ha_y : str
        Lower-case HA names (e.g. 'h3', 'h5', 'h7') for the x and y axes.
    axis_titles : dict, optional
        Maps HA name to a (multi-line) axis title. Defaults to ENTRY_EFFECT_AXIS_TITLES.

    Returns
    -------
    alt.Chart
    """
    titles = ENTRY_EFFECT_AXIS_TITLES if axis_titles is None else axis_titles
    return alt.Chart(df).mark_circle(
        size=25, opacity=0.3, color='#767676'
    ).encode(
        x=alt.X(f'{ha_x}_effect', title=titles[ha_x]),
        y=alt.Y(f'{ha_y}_effect', title=titles[ha_y]),
        tooltip=[
            'struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
            f'{ha_x}_effect', f'{ha_y}_effect'
        ]
    ).properties(
        width=200,
        height=200,
        title=f'{ha_x.upper()} vs. {ha_y.upper()}'
    )


h3_h7_scatter = mutation_effect_scatter(mutation_effects, 'h3', 'h7')
h3_h5_scatter = mutation_effect_scatter(mutation_effects, 'h3', 'h5')
h5_h7_scatter = mutation_effect_scatter(mutation_effects, 'h5', 'h7')

h3_h7_scatter | h3_h5_scatter | h5_h7_scatter
Out[5]:
In [6]:
def scatter_and_density_plot(df, ha_x, ha_y, colors, identity_range=(-5, 0.3)):
    """
    Scatter of per-site mean entry effects in two HA backgrounds with marginal densities.

    Points are colored by whether the wildtype amino acid is the same in both
    backgrounds. Density curves for ``ha_x`` (top) and ``ha_y`` (right) are split
    the same way. A dashed identity line and the Pearson r are overlaid on the scatter.

    Only sites with a mean effect in BOTH backgrounds are used, so the marginal
    densities describe exactly the plotted points.

    Parameters
    ----------
    df : pd.DataFrame
        Site-level table with ``avg_{ha}_effect`` and ``{ha}_wt_aa`` columns for both HAs.
    ha_x, ha_y : str
        Lower-case HA names used as column prefixes (e.g. 'h3', 'h5').
    colors : dict
        Maps 'Amino acid changed' / 'Amino acid conserved' to colors.
    identity_range : tuple of float, optional
        Start and end (same value on both axes) of the dashed identity line.

    Returns
    -------
    alt.VConcatChart
        x-density on top; scatter (+ identity line + r label) beside the y-density below.
    """
    x_col = f'avg_{ha_x}_effect'
    y_col = f'avg_{ha_y}_effect'

    # Sites measured in only one background have a NaN effect (and NaN wildtype) for
    # the other. They cannot appear in the scatter, and the wildtype comparison below
    # would wrongly label them "Amino acid changed". Drop them so every layer
    # (including density extents) uses the same sites.
    df = df.dropna(subset=[x_col, y_col])

    r_value = df[x_col].corr(df[y_col])
    r_text = f"r = {r_value:.2f}"

    line_start, line_end = identity_range
    identity_line = alt.Chart(
        pd.DataFrame({'x': [line_start, line_end], 'y': [line_start, line_end]})
    ).mark_line(
        strokeDash=[6, 6],
        color='black'
    ).encode(
        x='x',
        y='y'
    )

    df = df.assign(
        same_wildtype=lambda x: np.where(
            x[f'{ha_x}_wt_aa'] == x[f'{ha_y}_wt_aa'],
            'Amino acid conserved',
            'Amino acid changed'
        ),
    )

    # One color scale shared by the scatter and both marginal densities.
    color_scale = alt.Scale(domain=list(colors.keys()), range=list(colors.values()))

    scatter = alt.Chart(df).mark_circle(
        size=35, opacity=1, stroke='black', strokeWidth=0.5
    ).encode(
        x=alt.X(x_col, title=['Mean effect on cell entry', f'in {ha_x.upper()} background']),
        y=alt.Y(y_col, title=['Mean effect on cell entry', f'in {ha_y.upper()} background']),
        color=alt.Color('same_wildtype:N', scale=color_scale),
        tooltip=['struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa', x_col, y_col]
    ).properties(
        width=175,
        height=175,
    )

    r_label = alt.Chart(pd.DataFrame({'text': [r_text]})).mark_text(
        align='left',
        baseline='top',
        fontSize=16,
        fontWeight='normal',
        color='black'
    ).encode(
        text='text:N',
        x=alt.value(5),
        y=alt.value(5)
    )

    x_density = alt.Chart(df).transform_density(
        density=x_col,
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[x_col].min(), df[x_col].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1).encode(
        alt.X('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.Y('density:Q', title='Density').stack(None),
        color=alt.Color('same_wildtype:N', title=None, scale=color_scale),
    ).properties(
        width=175,
        height=50
    )

    y_density = alt.Chart(df).transform_density(
        density=y_col,
        bandwidth=0.3,
        groupby=['same_wildtype'],
        extent=[df[y_col].min(), df[y_col].max()],
        counts=True,
        steps=200
    ).mark_area(opacity=0.6, color='black', strokeWidth=1, orient='horizontal').encode(
        alt.Y('value:Q', axis=alt.Axis(labels=False, title=None, ticks=False)),
        alt.X('density:Q', title='Density').stack(None),
        color=alt.Color('same_wildtype:N', title=None, scale=color_scale),
    ).properties(
        width=50,
        height=175
    )

    marginal_plot = alt.vconcat(
        x_density,
        alt.hconcat(
            (scatter + identity_line + r_label),
            y_density
        )
    )
    return marginal_plot

# Wildtype-conservation colors shared by all three marginal plots.
colors = {
    'Amino acid changed': '#5484AF',
    'Amino acid conserved': '#E04948',
}

p1, p2, p3 = (
    scatter_and_density_plot(site_effects, ha_x, ha_y, colors=colors)
    for ha_x, ha_y in (('h3', 'h5'), ('h3', 'h7'), ('h5', 'h7'))
)

p1 | p2 | p3
Out[6]:

Calculate Jensen-Shannon divergence

In [7]:
def kl_divergence(p, q):
    """Kullback-Leibler divergence KL(p || q) (natural log), summed over all entries."""
    log_ratio = np.log(p / q)
    return np.sum(p * log_ratio)

def compute_js_divergence_per_site(df, ha_x, ha_y, site_col="struct_site", min_mutations=15):
    """
    Add a per-site Jensen-Shannon divergence column comparing two HA backgrounds.

    At each site, the mutation effects measured in both backgrounds are turned into
    probability distributions (exp, then normalized to sum to 1). The JSD (natural log)
    between the two distributions is written to ``JS_{ha_x}_vs_{ha_y}`` on every row of
    that site. Sites with fewer than ``min_mutations`` complete rows get NaN.
    Returns a copy; ``df`` is not modified.
    """
    x_col = f'{ha_x}_effect'
    y_col = f'{ha_y}_effect'

    def _site_jsd(site_rows):
        both = site_rows.dropna(subset=[x_col, y_col])
        if len(both) < min_mutations:
            return np.nan
        p = np.exp(both[x_col].values)
        q = np.exp(both[y_col].values)
        p = p / p.sum()
        q = q / q.sum()
        mix = 0.5 * (p + q)
        kl_p = np.sum(p * np.log(p / mix))
        kl_q = np.sum(q * np.log(q / mix))
        return 0.5 * (kl_p + kl_q)

    jsd_by_site = {site: _site_jsd(rows) for site, rows in df.groupby(site_col)}

    # Broadcast each site's JSD to all of that site's rows.
    out = df.copy()
    out[f"JS_{ha_x}_vs_{ha_y}"] = out[site_col].map(jsd_by_site)
    return out

# Per-site JSD for each HA pair, requiring >= 10 mutations measured in both backgrounds.
js_df_h3_h7, js_df_h3_h5, js_df_h5_h7 = (
    compute_js_divergence_per_site(mutation_effects, ha_x, ha_y, min_mutations=10)
    for ha_x, ha_y in (('h3', 'h7'), ('h3', 'h5'), ('h5', 'h7'))
)

Are epistatic shifts significant?

In [8]:
def compute_jsd_with_null(
    df,
    ha_x,
    ha_y,
    site_col="struct_site",
    min_mutations=15,
    n_bootstrap=1000,
    random_seed=42,
    jsd_threshold=0.02
):
    """
    Compute JS divergence with bootstrap null distribution for significance testing.

    The null distribution represents: "What JSD would I observe from measurement noise alone?"

    The null is generated by computing two separate null distributions:
    1. ha_x null: Sample ha_x twice with its measurement error, compute JSD
    2. ha_y null: Sample ha_y twice with its measurement error, compute JSD
    3. Take the mean of the two null distributions (balanced approach)

    Each "sample" draws every mutation's effect from a normal distribution centred
    on the measured effect with the measured ``*_effect_std`` as its scale. The
    effects are then turned into per-site probabilities (exp, normalized to sum to 1)
    exactly like the observed JSD.

    This accounts for measurement noise from both experiments without assuming they have
    identical underlying effects. A significant result means the observed JSD is larger
    than what measurement noise alone could produce.

    Only sites with observed JSD > jsd_threshold are tested for significance.

    Random draws come from a local ``np.random.RandomState(random_seed)``. This gives
    the same draws as seeding the global NumPy RNG, but leaves the global RNG state
    untouched.

    Parameters
    ----------
    df : pd.DataFrame
        Dataframe with mutation effects and effect_std columns
    ha_x, ha_y : str
        HA subtype names (e.g., 'h3', 'h5')
    site_col : str
        Column name for site identifier
    min_mutations : int
        Minimum number of mutations (rows with both effects and both stds) required at a site
    n_bootstrap : int
        Number of bootstrap iterations
    random_seed : int
        Random seed for reproducibility
    jsd_threshold : float
        Minimum JSD value for a site to be tested for significance.
        Sites with observed JSD <= threshold will have p_value = NaN.
        Default is 0.02.

    Returns
    -------
    pd.DataFrame
        DataFrame with columns:
        - struct_site: site identifier (value of ``site_col``)
        - JS_observed: observed JSD value
        - JS_null_mean: mean of null distribution (NaN if below threshold)
        - JS_null_std: standard deviation of null distribution (NaN if below threshold)
        - p_value: empirical p-value, the fraction of null JSDs >= observed
          (NaN if below threshold; resolution 1 / n_bootstrap, can be exactly 0)
        - n_mutations: number of mutations at site
        - null_distribution: array of the n_bootstrap null JSDs (None if below threshold)
        Sorted by struct_site using natural sorting. If no site qualifies, an empty
        DataFrame with these columns is returned.
    """
    rng = np.random.RandomState(random_seed)

    def _to_probabilities(effects):
        """exp + normalize to sum 1 along the last axis (works for 1-D and 2-D input)."""
        weights = np.exp(effects)
        return weights / weights.sum(axis=-1, keepdims=True)

    def _jsd(p, q):
        """Jensen-Shannon divergence (natural log) between distributions along the last axis."""
        m = 0.5 * (p + q)
        kl_p_m = np.sum(p * np.log(p / m), axis=-1)
        kl_q_m = np.sum(q * np.log(q / m), axis=-1)
        return 0.5 * (kl_p_m + kl_q_m)

    def _null_jsd(effects, std):
        """JSD between two independent noisy re-draws of ``effects``, repeated n_bootstrap times."""
        size = (n_bootstrap, len(effects))
        loc = effects[np.newaxis, :]  # broadcast to (1, n_mutations)
        scale = std[np.newaxis, :]
        draws_1 = rng.normal(loc=loc, scale=scale, size=size)
        draws_2 = rng.normal(loc=loc, scale=scale, size=size)
        return _jsd(_to_probabilities(draws_1), _to_probabilities(draws_2))

    required_cols = [
        f'{ha_x}_effect', f'{ha_y}_effect',
        f'{ha_x}_effect_std', f'{ha_y}_effect_std'
    ]
    results = []

    for site, group in df.groupby(site_col):
        # Filter to valid mutations with both effects and stds
        valid = group.dropna(subset=required_cols)

        if len(valid) < min_mutations:
            continue

        effects_x = valid[f'{ha_x}_effect'].values
        effects_y = valid[f'{ha_y}_effect'].values
        std_x = valid[f'{ha_x}_effect_std'].values
        std_y = valid[f'{ha_y}_effect_std'].values

        # Observed JSD between ha_x and ha_y
        jsd_obs = _jsd(_to_probabilities(effects_x), _to_probabilities(effects_y))

        # Only compute null distribution if JSD exceeds threshold
        if jsd_obs <= jsd_threshold:
            results.append({
                'struct_site': site,
                'JS_observed': jsd_obs,
                'JS_null_mean': np.nan,
                'JS_null_std': np.nan,
                'p_value': np.nan,
                'n_mutations': len(valid),
                'null_distribution': None
            })
            continue

        # Draw order (x then y) is fixed so results are reproducible for a given seed.
        jsd_null_x = _null_jsd(effects_x, std_x)
        jsd_null_y = _null_jsd(effects_y, std_y)

        # Take the mean of the two nulls (balanced approach)
        jsd_null = (jsd_null_x + jsd_null_y) / 2

        # Empirical one-tailed p-value: is observed JSD greater than the null?
        p_value = np.mean(jsd_null >= jsd_obs)

        results.append({
            'struct_site': site,
            'JS_observed': jsd_obs,
            'JS_null_mean': jsd_null.mean(),
            'JS_null_std': jsd_null.std(),
            'p_value': p_value,
            'n_mutations': len(valid),
            'null_distribution': jsd_null  # Store for visualization
        })

    # Explicit columns so an empty result still has the documented schema
    # (sorting an empty, column-less frame by 'struct_site' would raise KeyError).
    columns = [
        'struct_site', 'JS_observed', 'JS_null_mean', 'JS_null_std',
        'p_value', 'n_mutations', 'null_distribution'
    ]
    results_df = pd.DataFrame(results, columns=columns)
    if results_df.empty:
        return results_df

    return results_df.sort_values('struct_site', key=natsort_keygen()).reset_index(drop=True)
In [9]:
# Compute JSD with null distributions for each comparison
# (>= 10 mutations per site, 1000 bootstrap draws, test only sites with JSD > 0.02).
jsd_with_pvals_h3_h5, jsd_with_pvals_h3_h7, jsd_with_pvals_h5_h7 = (
    compute_jsd_with_null(
        pair_js_df, ha_x, ha_y,
        min_mutations=10,
        n_bootstrap=1000,
        jsd_threshold=0.02
    )
    for pair_js_df, ha_x, ha_y in (
        (js_df_h3_h5, 'h3', 'h5'),
        (js_df_h3_h7, 'h3', 'h7'),
        (js_df_h5_h7, 'h5', 'h7'),
    )
)

# Apply multiple testing correction (Benjamini-Hochberg FDR)
# Only apply FDR to sites that were tested (non-NaN p-values)
from scipy.stats import false_discovery_control

def apply_fdr_with_threshold(df):
    """Apply FDR correction only to non-NaN p-values."""
    # Initialize q_value column with NaN
    df['q_value'] = np.nan
    
    # Get indices of non-NaN p-values
    tested_mask = df['p_value'].notna()
    
    if tested_mask.sum() > 0:
        # Apply FDR correction only to tested sites
        df.loc[tested_mask, 'q_value'] = false_discovery_control(df.loc[tested_mask, 'p_value'])
    
    return df

# BH-correct each comparison; sites below the JSD threshold keep q_value = NaN.
jsd_with_pvals_h3_h5, jsd_with_pvals_h3_h7, jsd_with_pvals_h5_h7 = [
    apply_fdr_with_threshold(pvals)
    for pvals in (jsd_with_pvals_h3_h5, jsd_with_pvals_h3_h7, jsd_with_pvals_h5_h7)
]

# Report significant sites (q < 0.1) as fractions of ALL sites with a JSD measurement,
# tested or not.
total_h3h5, total_h3h7, total_h5h7 = map(
    len, (jsd_with_pvals_h3_h5, jsd_with_pvals_h3_h7, jsd_with_pvals_h5_h7)
)
sig_h3h5, sig_h3h7, sig_h5h7 = (
    pvals['q_value'].lt(0.1).sum()
    for pvals in (jsd_with_pvals_h3_h5, jsd_with_pvals_h3_h7, jsd_with_pvals_h5_h7)
)

for header, n_sig, n_total in (
    ("Significant sites (H3 vs H5, q < 0.1):", sig_h3h5, total_h3h5),
    ("\nSignificant sites (H3 vs H7, q < 0.1):", sig_h3h7, total_h3h7),
    ("\nSignificant sites (H5 vs H7, q < 0.1):", sig_h5h7, total_h5h7),
):
    print(header)
    print(f"  {n_sig} / {n_total} sites ({n_sig/n_total:.2%})")
Significant sites (H3 vs H5, q < 0.1):
  270 / 468 sites (57.69%)

Significant sites (H3 vs H7, q < 0.1):
  253 / 467 sites (54.18%)

Significant sites (H5 vs H7, q < 0.1):
  212 / 432 sites (49.07%)
In [10]:
def plot_jsd(df, jsd_pvals_df, ha_x, ha_y, identity_df=None, alpha=0.1, only_lineplot=False, variant_selector=None):
    """
    Plot per-site JS divergence between two HA backgrounds, colored by significance.

    The top panel is a line + point plot of per-site divergence along the site axis.
    Points are colored by ``q_value < alpha`` and enlarged/outlined when selected by
    ``variant_selector``. Unless ``only_lineplot`` is True, two more panels are
    stacked below:
    - a scatter of per-mutation entry effects (drawn as amino-acid letters) for the
      selected site;
    - a density of the per-site divergences.

    Neither ``df`` nor ``jsd_pvals_df`` is modified.

    Parameters
    ----------
    df : pd.DataFrame
        Mutation-level dataframe with ``{ha}_effect``, ``{ha}_effect_std``, ``{ha}_wt_aa``,
        ``rmsd_{ha_x}{ha_y}`` and ``JS_{ha_x}_vs_{ha_y}`` columns.
    jsd_pvals_df : pd.DataFrame
        DataFrame with ``struct_site`` and ``q_value`` columns (compute_jsd_with_null
        followed by apply_fdr_with_threshold).
    ha_x, ha_y : str
        HA subtype names used as column prefixes (e.g. 'h3', 'h5').
    identity_df : pd.DataFrame, optional
        Sequence-identity table with ``ha_x``, ``ha_y`` (upper-case names) and
        ``percent_identity``. If the pair is found, the identity is added to the title.
    alpha : float
        Significance threshold for q-value (default 0.1)
    only_lineplot : bool
        If True, return only the per-site line + point plot.
    variant_selector : alt.selection_point, optional
        Shared selection object for synchronized highlighting across plots. If
        omitted, a mouseover point selection on ``struct_site`` is created.

    Returns
    -------
    alt.LayerChart or alt.VConcatChart
        The titled line plot (``only_lineplot=True``) or the three stacked panels.
    """
    if identity_df is not None:
        result = identity_df.query(
            f'ha_x=="{ha_x.upper()}" and ha_y=="{ha_y.upper()}"'
        )
        shared_aai = result['percent_identity'].values[0] if len(result) > 0 else None
    else:
        shared_aai = None

    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }

    # Work on a copy: cast site labels to str for the ordinal axis and the merge
    # below without rewriting the caller's dataframe.
    df = df.assign(
        struct_site=lambda x: x['struct_site'].astype(str),
        mutant_type=lambda x: x['mutant'].map(amino_acid_classification)
    )

    # Cast the p-value table's site labels the same way. Merging object keys with
    # int64 keys raises in pandas.
    site_qvals = jsd_pvals_df[['struct_site', 'q_value']].assign(
        struct_site=lambda x: x['struct_site'].astype(str)
    )

    # One row per site with its divergence, then attach q-values.
    site_jsd_df = df[[
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
        f'JS_{ha_x}_vs_{ha_y}', f'rmsd_{ha_x}{ha_y}'
    ]].dropna().drop_duplicates()

    site_jsd_df = site_jsd_df.merge(
        site_qvals,
        on='struct_site',
        how='left'
    )

    # Untested sites have q_value NaN, so they compare False (not significant).
    site_jsd_df = site_jsd_df.assign(
        significant=lambda x: x['q_value'] < alpha
    )

    # Use provided selector or create a new one
    if variant_selector is None:
        variant_selector = alt.selection_point(
            on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
        )

    sorted_sites = natsorted(df['struct_site'].unique())
    base = alt.Chart(site_jsd_df).encode(
        alt.X(
            "struct_site:O",
            sort=sorted_sites,
            title='Site',
            axis=alt.Axis(
                labelAngle=0,
                values=['1', '50', '100', '150', '200', '250', '300', '350', '400', '450', '500'],
                tickCount=11,
            )
        ),
        alt.Y(
            f'JS_{ha_x}_vs_{ha_y}:Q',
            title=['Divergence in amino-acid', 'preferences'],
            axis=alt.Axis(
                grid=False
            ),
            scale=alt.Scale(domain=[0, 0.7])
        ),
        tooltip=[
            'struct_site',
            f'{ha_x}_wt_aa',
            f'{ha_y}_wt_aa',
            alt.Tooltip(f'JS_{ha_x}_vs_{ha_y}', format='.4f'),
            alt.Tooltip(f'rmsd_{ha_x}{ha_y}', format='.2f'),
            alt.Tooltip('q_value', format='.4f'),
            'significant'
        ],
    ).properties(
        width=800,
        height=150
    )

    line = base.mark_line(opacity=0.5, stroke='#999999', size=1)

    # Points layer with conditional formatting based on the selection
    points = base.mark_circle(filled=True).encode(
        size=alt.condition(
            variant_selector,
            alt.value(75),  # when selected
            alt.value(40)  # default
        ),
        color=alt.Color(
            'significant:N',
            title=['Significant', f'(FDR < {alpha})'],
            scale=alt.Scale(domain=[True, False], range=['#E15759', '#BAB0AC']),
            legend=alt.Legend(
                titleFontSize=14,
                labelFontSize=12
            )
        ),
        stroke=alt.condition(
            variant_selector,
            alt.value('black'),
            alt.value(None)
        ),
        strokeWidth=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0)
        ),
        opacity=alt.condition(
            variant_selector,
            alt.value(1),
            alt.value(0.75)
        )
    ).add_params(
        variant_selector
    )

    # Per-mutation entry effects at the selected site only (filtered by the selection)
    base_corr_chart = (alt.Chart(df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6, 1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6, 1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic', 'Negative', 'Positive', 'Special'],
                        range=["#4e79a7", "#f28e2c", "#e15759", "#76b7b2", "#59a14f", "#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                     f'JS_{ha_x}_vs_{ha_y}'],
        )
        .transform_filter(
            variant_selector
        )
        .properties(
            height=150,
            width=150,
        )
    )

    # Vertical line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray', opacity=0.5, strokeDash=[2, 4]).encode(x='x:Q')

    # Horizontal line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray', opacity=0.5, strokeDash=[2, 4]).encode(y='y:Q')

    corr_chart = vline + hline + base_corr_chart

    # Density of per-site divergence values
    density = alt.Chart(
        site_jsd_df
    ).transform_density(
        density=f'JS_{ha_x}_vs_{ha_y}',
        bandwidth=0.02,
        extent=[0, 1],
        counts=True,
        steps=200
    ).mark_area(opacity=1, color='#CCEBC5', stroke='black', strokeWidth=1).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q', title='Density').stack(None),
    ).properties(
        width=200,
        height=60
    )

    if shared_aai is not None:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()} ({shared_aai:.1f}% Amino Acid Identity)'
    else:
        title_text = f'{ha_x.upper()} vs. {ha_y.upper()}'

    # Stack the line plot, effect scatter and density (or return just the line plot)
    if only_lineplot is False:
        combined_chart = alt.vconcat(
            (line + points), corr_chart, density
        ).resolve_scale(
            y='independent',
            x='independent',
            color='independent'
        )
    else:
        combined_chart = line + points

    combined_chart = combined_chart.properties(
        title=alt.Title(title_text,
        offset=0,
        fontSize=18,
        subtitleFontSize=16,
        anchor='middle'
        )
    )

    return combined_chart

# H3 vs. H5: per-site divergence, effects at the hovered site, and divergence density.
chart = plot_jsd(js_df_h3_h5, jsd_with_pvals_h3_h5, 'h3', 'h5', identity_df=seq_identity)
chart.display()
In [11]:
# H3 vs. H7: per-site divergence, effects at the hovered site, and divergence density.
chart = plot_jsd(js_df_h3_h7, jsd_with_pvals_h3_h7, 'h3', 'h7', identity_df=seq_identity)
chart.display()
In [12]:
# H5 vs. H7: per-site divergence, effects at the hovered site, and divergence density.
chart = plot_jsd(js_df_h5_h7, jsd_with_pvals_h5_h7, 'h5', 'h7', identity_df=seq_identity)
chart.display()
In [13]:
# One hover selection shared by all three line plots, so hovering a site highlights it in each.
shared_selection = alt.selection_point(
    on="mouseover", empty=False, nearest=True, fields=["struct_site"], value=1
)

# Panels in display order: H3 vs. H5, H5 vs. H7, H3 vs. H7.
_lineplot_panels = [
    plot_jsd(
        pair_js_df, pair_pvals_df, ha_x, ha_y, seq_identity,
        only_lineplot=True, variant_selector=shared_selection
    )
    for pair_js_df, pair_pvals_df, ha_x, ha_y in (
        (js_df_h3_h5, jsd_with_pvals_h3_h5, 'h3', 'h5'),
        (js_df_h5_h7, jsd_with_pvals_h5_h7, 'h5', 'h7'),
        (js_df_h3_h7, jsd_with_pvals_h3_h7, 'h3', 'h7'),
    )
]
combined_interactive_lineplots = _lineplot_panels[0] & _lineplot_panels[1] & _lineplot_panels[2]

combined_interactive_lineplots.save('combined_interactive_lineplots.html')
combined_interactive_lineplots.display()
In [14]:
# Write one per-site divergence table per HA pair (one row per unique site record).
# NOTE(review): every table uses the 4o5n RSA column, including H5 vs. H7 —
# confirm this single reference structure is intended.
for pair_js_df, ha_x, ha_y in (
    (js_df_h3_h5, 'h3', 'h5'),
    (js_df_h3_h7, 'h3', 'h7'),
    (js_df_h5_h7, 'h5', 'h7'),
):
    export_cols = [
        'struct_site', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
        f'rmsd_{ha_x}{ha_y}', '4o5n_aa_RSA', f'JS_{ha_x}_vs_{ha_y}'
    ]
    (
        pair_js_df[export_cols]
        .drop_duplicates()
        .reset_index(drop=True)
        .to_csv(f'../results/divergence/{ha_x}_{ha_y}_divergence.csv', index=False)
    )

H7 entry in α2,6 vs. α2,3 cells

In [15]:
def read_and_filter_data(
    path,
    effect_std_filter=2,
    times_seen_filter=2,
    n_selections_filter=2,
    clip_effect=-5
):
    """
    Read a per-mutation entry-effect CSV, filter it, and add zero-effect wildtype rows.

    Keeps rows with ``effect_std <= effect_std_filter``, ``times_seen >= times_seen_filter``
    and ``n_selections >= n_selections_filter``, and drops stop codons ('*') and
    indels ('-'). ``site`` is cast to str and ``effect`` is clipped from below at
    ``clip_effect``. For every (site, wildtype) pair that remains, a row with
    mutant == wildtype and effect = effect_std = 0 is appended. The result is sorted
    by (site, mutant); site is a string, so the order is lexicographic.

    Parameters
    ----------
    path : str or path-like
        CSV with site, wildtype, mutant, effect, effect_std, times_seen and
        n_selections columns.
    effect_std_filter, times_seen_filter, n_selections_filter : float
        Filter thresholds described above.
    clip_effect : float
        Lower bound applied to ``effect`` (default -5).

    Returns
    -------
    pd.DataFrame
    """
    print(f'Reading data from {path}')
    print(
        f"Filtering for:\n"
        f"  effect_std <= {effect_std_filter}\n"
        f"  times_seen >= {times_seen_filter}\n"
        f"  n_selections >= {n_selections_filter}"
    )
    print(f"Clipping effect values at {clip_effect}")

    df = (
        pd.read_csv(path)
        .query(
            'effect_std <= @effect_std_filter and '
            'times_seen >= @times_seen_filter and '
            'n_selections >= @n_selections_filter'
        )
        .query('mutant not in ["*", "-"]')  # don't want stop codons/indels
        # .assign returns a new frame (no SettingWithCopy on the query result) and
        # honours the clip_effect argument instead of a hard-coded -5.
        .assign(
            site=lambda x: x['site'].astype(str),
            effect=lambda x: x['effect'].clip(lower=clip_effect),
        )
    )

    wildtype_rows = df[['site', 'wildtype']].drop_duplicates().assign(
        mutant=lambda x: x['wildtype'],
        effect=0.0,
        effect_std=0.0,
        times_seen=np.nan,
        n_selections=np.nan
    )  # add wildtype sites with zero effect

    df = pd.concat(
        [df, wildtype_rows], ignore_index=True
    ).sort_values(['site', 'mutant']).reset_index(drop=True)

    return df
In [16]:
def _load_h7_entry_effects(path, label):
    """Read one H7 entry dataset and prefix its wildtype/effect columns with ``label``."""
    loaded = read_and_filter_data(path)
    return loaded[['site', 'wildtype', 'mutant', 'effect', 'effect_std']].rename(
        columns={
            'site': 'struct_site',
            'wildtype': f'{label}_wt_aa',
            'mutant': 'mutant',
            'effect': f'{label}_effect',
            'effect_std': f'{label}_effect_std',
        }
    )


h7_23_df = _load_h7_entry_effects(
    '../data/cell_entry_effects/293_2-3_entry_func_effects.csv', 'h7_2-3'
)
h7_26_df = _load_h7_entry_effects(
    '../data/cell_entry_effects/293_2-6_entry_func_effects.csv', 'h7_2-6'
)

# Inner-join mutations present in both datasets. Both are the same HA, so the
# structural RMSD between them is 0.
h7_23_26_df = h7_23_df.merge(
    h7_26_df,
    left_on=['struct_site', 'h7_2-3_wt_aa', 'mutant'],
    right_on=['struct_site', 'h7_2-6_wt_aa', 'mutant'],
).assign(**{'rmsd_h7_2-3h7_2-6': 0})

h7_23_26_df.head()
Reading data from ../data/cell_entry_effects/293_2-3_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Reading data from ../data/cell_entry_effects/293_2-6_entry_func_effects.csv
Filtering for:
  effect_std <= 2
  times_seen >= 2
  n_selections >= 2
Clipping effect values at -5
Out[16]:
struct_site h7_2-3_wt_aa mutant h7_2-3_effect h7_2-3_effect_std h7_2-6_wt_aa h7_2-6_effect h7_2-6_effect_std rmsd_h7_2-3h7_2-6
0 100 G A -0.00515 0.85400 G -1.276 0.719 0
1 100 G C -3.90900 0.01169 G -4.422 0.000 0
2 100 G D -4.78700 0.00000 G -4.936 0.000 0
3 100 G G 0.00000 0.00000 G 0.000 0.000 0
4 100 G H -4.63900 0.00000 G -4.796 0.000 0
In [17]:
# Per-site JSD between the two H7 entry datasets (same HA, different target cells).
js_df_h7_23_26 = compute_js_divergence_per_site(
    h7_23_26_df, ha_x='h7_2-3', ha_y='h7_2-6', min_mutations=10
)
In [18]:
# Null distributions + BH FDR for the H7 a2,3 vs. a2,6 comparison (same settings as above).
jsd_with_pvals_h7_23_26 = apply_fdr_with_threshold(
    compute_jsd_with_null(
        js_df_h7_23_26,
        'h7_2-3', 'h7_2-6',
        min_mutations=10,
        n_bootstrap=1000,
        jsd_threshold=0.02
    )
)

# Report significant sites as fractions (out of ALL sites with JSD measurements)
total_h7_23_26 = len(jsd_with_pvals_h7_23_26)
sig_h7_23_26 = jsd_with_pvals_h7_23_26['q_value'].lt(0.1).sum()
print("Significant sites (H7 2-3 vs H7 2-6, q < 0.1):")
print(f"  {sig_h7_23_26} / {total_h7_23_26} sites ({sig_h7_23_26/total_h7_23_26:.2%})")
Significant sites (H7 2-3 vs H7 2-6, q < 0.1):
  0 / 492 sites (0.00%)
In [19]:
# No sequence-identity title annotation: both datasets are the same H7 HA.
chart = plot_jsd(js_df_h7_23_26, jsd_with_pvals_h7_23_26, ha_x='h7_2-3', ha_y='h7_2-6')
chart.display()
In [20]:
def plot_ridgeline_density(dfs_dict, x_col_template='JS_{ha_x}_vs_{ha_y}',
                           bandwidth=0.02, extent=[0,0.65],
                           colors=None, width=200, height=400,
                           overlap=2.5, label_mapping=None):
    """
    Plot ridgeline (joyplot) density plots for multiple dataframes.

    Each comparison becomes one facet row. Negative facet spacing makes neighbouring
    rows overlap to give the ridgeline look.

    Parameters:
    -----------
    dfs_dict : dict
        Dictionary where keys are comparison labels (e.g., 'h3-h5', 'h3-h7')
        and values are tuples of (df, ha_x, ha_y)
        Example: {'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
                  'h3-h7': (js_df_h3_h7, 'h3', 'h7')}
    x_col_template : str
        Template for column name with {ha_x} and {ha_y} placeholders
    bandwidth : float
        Bandwidth for density estimation
    extent : list
        [min, max] for density calculation (not modified)
    colors : list or None
        Fill colors for the comparisons; only the first len(dfs_dict) are used.
        If None, uses a default color scheme.
    width, height : int
        Width of each ridge; ``height`` sets each ridge's height as
        height / (len(dfs_dict) * overlap).
    overlap : float
        How much the ridges overlap (higher = more overlap)
    label_mapping : dict or None
        Optional map from ``dfs_dict`` keys to display labels (a list of strings
        can be given as a label). Keys missing from the mapping keep their
        original key as the label instead of becoming missing.

    Returns:
    --------
    alt.FacetChart : Ridgeline density plot with facet/view/header config applied
    """
    # Default color scheme if none provided
    if colors is None:
        colors = ['#8DD3C7', '#FFFFB3', '#BEBADA', '#FB8072', '#80B1D3', '#FDB462']

    # Stack every comparison's values into one long table keyed by comparison.
    combined_data = []
    for comparison, (df, ha_x, ha_y) in dfs_dict.items():
        col_name = x_col_template.format(ha_x=ha_x, ha_y=ha_y)
        temp_df = df[[col_name]].copy()
        temp_df['comparison'] = comparison
        temp_df['value'] = temp_df[col_name]
        combined_data.append(temp_df[['value', 'comparison']])

    combined_df = pd.concat(combined_data, ignore_index=True)

    if label_mapping is not None:
        # Series.map(dict) turns unmapped keys into NaN, which would silently drop or
        # mislabel a ridge. Fall back to the original key instead.
        combined_df['comparison'] = combined_df['comparison'].map(
            lambda key: label_mapping.get(key, key)
        )

    # Row height for each ridge; negative facet spacing below makes the rows overlap.
    step = height / (len(dfs_dict) * overlap)

    ridgeline = alt.Chart(combined_df).transform_density(
        density='value',
        bandwidth=bandwidth,
        extent=extent,
        groupby=['comparison'],
        steps=200
    ).mark_area(
        opacity=1,
        stroke='black',
        strokeWidth=1,
        interpolate='monotone'
    ).encode(
        alt.X('value:Q', title=['Divergence in amino-acid', 'preferences']),
        alt.Y('density:Q',
              title='Density',
              axis=None),
        alt.Row('comparison:N',
                title=None,
                header=alt.Header(labelAngle=0, labelAlign='left')),
        alt.Fill('comparison:N',
                 legend=None,
                 scale=alt.Scale(range=colors[:len(dfs_dict)]))
    ).properties(
        width=width,
        height=step,
        bounds='flush'
    ).configure_facet(
        spacing=-(step * (overlap - 1))
    ).configure_view(
        stroke=None
    ).configure_header(
        labelFontSize=14
    )

    return ridgeline

# Ridgeline of per-site divergence for every comparison, including the
# same-HA H7 a2,3 vs. a2,6 comparison.
dfs_to_plot = {
    'h3-h5': (js_df_h3_h5, 'h3', 'h5'),
    'h3-h7': (js_df_h3_h7, 'h3', 'h7'),
    'h5-h7': (js_df_h5_h7, 'h5', 'h7'),
    'h7_2-3-h7_2-6': (js_df_h7_23_26, 'h7_2-3', 'h7_2-6'),
}

ridgeline_labels = {
    'h3-h5': 'H3 vs. H5',
    'h3-h7': 'H3 vs. H7',
    'h5-h7': 'H5 vs. H7',
    'h7_2-3-h7_2-6': ['H7 (a2,3) vs.', 'H7 (a2,6)'],
}

plot_ridgeline_density(dfs_to_plot, label_mapping=ridgeline_labels).display()

Examples of mutation effect correlations

In [21]:
def plot_correlation(df, ha_x, ha_y, site, decimal_places=2):
    """Plot the effects of every mutant at one site in two HAs against each other.

    Each mutant is drawn as its one-letter amino-acid code and colored by a
    coarse physicochemical class. Dashed reference lines are drawn at x = 0
    and y = 0. The title shows the site and the value of the
    ``JS_{ha_x}_vs_{ha_y}`` column for that site.

    Parameters
    ----------
    df : pandas.DataFrame
        Per-mutant table with columns ``struct_site``, ``mutant``,
        ``{ha}_wt_aa``, ``{ha}_effect`` and ``{ha}_effect_std`` for both HAs,
        plus ``JS_{ha_x}_vs_{ha_y}``. The frame is NOT modified.
    ha_x, ha_y : str
        Column prefixes of the HAs shown on the x and y axes (e.g. ``'h3'``).
    site : str or int
        Structural site to plot. It is compared to ``struct_site`` as a string.
    decimal_places : int, default 2
        Number of decimals used to print the divergence in the title.

    Returns
    -------
    Altair layered chart
        The reference lines layered under the labeled mutant scatter.

    Raises
    ------
    ValueError
        If ``site`` does not occur in ``df['struct_site']``.
    """
    amino_acid_classification = {
        'F': 'Aromatic', 'Y': 'Aromatic', 'W': 'Aromatic',
        'N': 'Hydrophilic', 'Q': 'Hydrophilic', 'S': 'Hydrophilic', 'T': 'Hydrophilic',
        'A': 'Hydrophobic', 'V': 'Hydrophobic', 'I': 'Hydrophobic', 'L': 'Hydrophobic', 'M': 'Hydrophobic',
        'D': 'Negative', 'E': 'Negative',
        'R': 'Positive', 'H': 'Positive', 'K': 'Positive',
        'C': 'Special', 'G': 'Special', 'P': 'Special'
    }
    site = str(site)
    js_col = f'JS_{ha_x}_vs_{ha_y}'

    # `assign` returns a new frame, so the caller's `struct_site` column keeps
    # its original dtype. Casting in place on `df` leaked a str column back to
    # the shared js_df_* frames.
    site_df = df.assign(
        struct_site=df['struct_site'].astype(str),
        mutant_type=df['mutant'].map(amino_acid_classification),
    )
    site_df = site_df[site_df['struct_site'] == site]
    if site_df.empty:
        raise ValueError(f'site {site!r} not found in struct_site column')

    # The divergence value for the site (the first distinct value in the column).
    jsd = site_df[js_col].unique()[0]

    base_corr_chart = (alt.Chart(site_df)
        .mark_text(size=20)
        .encode(
            alt.X(
                f"{ha_x}_effect",
                title=["Effect on cell entry", f"in {ha_x.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Y(
                f"{ha_y}_effect",
                title=["Effect on cell entry", f"in {ha_y.upper()}"],
                scale=alt.Scale(domain=[-6,1.5])
            ),
            alt.Text('mutant'),
            alt.Color('mutant_type',
                    title='Mutant type',
                    scale=alt.Scale(
                        domain=['Aromatic', 'Hydrophilic', 'Hydrophobic','Negative', 'Positive', 'Special'],
                        range=["#4e79a7","#f28e2c","#e15759","#76b7b2","#59a14f","#edc949"]
                    ),
                    legend=alt.Legend(
                        titleFontSize=16,
                        labelFontSize=13
                    )
            ),
            tooltip=['struct_site', 'mutant', f'{ha_x}_wt_aa', f'{ha_y}_wt_aa',
                     f'{ha_x}_effect', f'{ha_x}_effect_std',
                     f'{ha_y}_effect', f'{ha_y}_effect_std',
                     js_col],
        ).properties(
            height=125,
            width=125,
        )
    )

    # Dashed reference line at x = 0
    vline = alt.Chart(pd.DataFrame({'x': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(x='x:Q')

    # Dashed reference line at y = 0
    hline = alt.Chart(pd.DataFrame({'y': [0]})).mark_rule(color='gray',opacity=0.5,strokeDash=[2,4]).encode(y='y:Q')

    corr_chart = (vline + hline + base_corr_chart).properties(
        title=alt.Title([f'Site {site}', f'Divergence = {jsd:.{decimal_places}f}'],
        offset=0,
        fontSize=16,
        anchor='middle'
        )
    )
    return corr_chart
In [22]:
# H3 vs. H5 mutation-effect correlations at four example sites.
# Site 97 gets three decimals in its divergence title; the rest use two.
alt.hconcat(*[
    plot_correlation(js_df_h3_h5, 'h3', 'h5', site=example_site, decimal_places=n_decimals)
    for example_site, n_decimals in [('86', 2), ('97', 3), ('198', 2), ('241', 2)]
]).display()
In [23]:
# Site 86 across all three HA pairs.
alt.hconcat(*(
    plot_correlation(js_frame, ha_x, ha_y, site='86')
    for js_frame, ha_x, ha_y in (
        (js_df_h3_h5, 'h3', 'h5'),
        (js_df_h3_h7, 'h3', 'h7'),
        (js_df_h5_h7, 'h5', 'h7'),
    )
)).display()
In [24]:
# Site 173 across all three HA pairs.
alt.hconcat(*(
    plot_correlation(js_frame, ha_x, ha_y, site='173')
    for js_frame, ha_x, ha_y in (
        (js_df_h3_h5, 'h3', 'h5'),
        (js_df_h3_h7, 'h3', 'h7'),
        (js_df_h5_h7, 'h5', 'h7'),
    )
)).display()
In [25]:
# Sites 178, 123 and 176, each shown across all three HA pairs (one row per
# site, in that order).
# Interpretation: H3 forms hydrogen bonds at sites 178, 123, 176, and 211.
# H5 and H7 do not form any hydrogen bonds in this region, and therefore
# tolerate many more amino acids.
for hbond_site in ['178', '123', '176']:
    (
        plot_correlation(js_df_h3_h5, 'h3', 'h5', site=hbond_site) |
        plot_correlation(js_df_h3_h7, 'h3', 'h7', site=hbond_site) |
        plot_correlation(js_df_h5_h7, 'h5', 'h7', site=hbond_site)
    ).display()